{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "description: Continual Learning Algorithms Prototyping Made Easy\n",
    "---\n",
    "# Training\n",
    "\n",
    "Welcome to the \"_Training_\" tutorial of the \"_From Zero to Hero_\" series. In this part we will present the functionalities offered by the `training` module.\n",
    "\n",
    "First, let's install Avalanche. You can skip this step if you have installed it already."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "!pip install avalanche-lib==0.5"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 💪 The Training Module\n",
    "\n",
    "The `training` module in _Avalanche_ is designed with modularity in mind. Its main goals are to:\n",
    "\n",
    "1. provide a set of popular **continual learning baselines** that can be easily used to run experimental comparisons;\n",
    "2. provide simple abstractions to **create and run your own strategy** as efficiently and easily as possible starting from a couple of basic building blocks we already prepared for you.\n",
    "\n",
    "At the moment, the `training` module includes three main components:\n",
    "\n",
    "* **Templates**: these are high level abstractions used as a starting point to define the actual strategies. The templates contain already implemented basic utilities and functionalities shared by a group of strategies (e.g. the `BaseSGDTemplate` contains all the implemented methods to deal with strategies based on SGD).\n",
    "* **Strategies**: these are popular baselines already implemented for you which you can use for comparisons or as base classes to define a custom strategy.\n",
    "* **Plugins**: these are classes that allow adding some specific behavior to your own strategy. The plugin system allows defining reusable components which can be easily combined (e.g. a replay strategy, a regularization strategy). They are also used to automatically manage logging and evaluation.\n",
    "\n",
    "Keep in mind that many Avalanche components are independent of Avalanche strategies. If you already have your own strategy which does not use Avalanche, you can use Avalanche's benchmarks, models, data loaders, and metrics without ever looking at Avalanche's strategies!\n",
    "\n",
    "## 📈 How to Use Strategies & Plugins\n",
    "\n",
    "If you want to compare your strategy with other classic continual learning algorithms or baselines, in _Avalanche_ you can instantiate a strategy with a couple of lines of code.\n",
    "\n",
    "### Strategy Instantiation\n",
    "Most strategies require only 3 mandatory arguments:\n",
    "- **model**: this must be a `torch.nn.Module`.\n",
    "- **optimizer**: `torch.optim.Optimizer` already initialized on your `model`.\n",
    "- **loss**: a loss function such as those in `torch.nn.functional`.\n",
    "\n",
    "Additional arguments are optional and allow you to customize training (batch size, number of epochs, ...) or strategy-specific parameters (memory size, regularization strength, ...)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torch.optim import SGD\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from avalanche.models import SimpleMLP\n",
    "from avalanche.training.supervised import Naive, CWRStar, Replay, GDumb, Cumulative, LwF, GEM, AGEM, EWC  # and many more!\n",
    "\n",
    "model = SimpleMLP(num_classes=10)\n",
    "optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
    "criterion = CrossEntropyLoss()\n",
    "cl_strategy = Naive(\n",
    "    model, optimizer, criterion,\n",
    "    train_mb_size=100, train_epochs=4, eval_mb_size=100\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Training & Evaluation\n",
    "\n",
    "Each strategy object offers two main methods: `train` and `eval`. Both of them, accept either a _single experience_(`Experience`) or a _list of them_, for maximum flexibility.\n",
    "\n",
    "We can train the model continually by iterating over the `train_stream` provided by the scenario."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting experiment...\n",
      "Start of experience:  0\n",
      "Current Classes:  [5, 6]\n",
      "-- >> Start of training phase << --\n",
      "100%|██████████| 114/114 [00:00<00:00, 183.70it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.3980\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8935\n",
      "100%|██████████| 114/114 [00:00<00:00, 182.57it/s]\n",
      "Epoch 1 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.1019\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9683\n",
      "100%|██████████| 114/114 [00:00<00:00, 178.79it/s]\n",
      "Epoch 2 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0862\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9737\n",
      "100%|██████████| 114/114 [00:00<00:00, 188.97it/s]\n",
      "Epoch 3 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0778\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9742\n",
      "-- >> End of training phase << --\n",
      "Training completed\n",
      "Computing accuracy on the whole test set\n",
      "-- >> Start of eval phase << --\n",
      "-- Starting eval on experience 0 (Task 0) from test stream --\n",
      "100%|██████████| 19/19 [00:00<00:00, 232.93it/s]\n",
      "> Eval on experience 0 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp000 = 0.0650\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.9805\n",
      "-- Starting eval on experience 1 (Task 0) from test stream --\n",
      "100%|██████████| 22/22 [00:00<00:00, 249.61it/s]\n",
      "> Eval on experience 1 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp001 = 7.9174\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000\n",
      "-- Starting eval on experience 2 (Task 0) from test stream --\n",
      "100%|██████████| 20/20 [00:00<00:00, 221.20it/s]\n",
      "> Eval on experience 2 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp002 = 10.7715\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.0000\n",
      "-- Starting eval on experience 3 (Task 0) from test stream --\n",
      "100%|██████████| 21/21 [00:00<00:00, 266.36it/s]\n",
      "> Eval on experience 3 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp003 = 10.5082\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.0000\n",
      "-- Starting eval on experience 4 (Task 0) from test stream --\n",
      "100%|██████████| 21/21 [00:00<00:00, 239.92it/s]\n",
      "> Eval on experience 4 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp004 = 8.9207\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000\n",
      "-- >> End of eval phase << --\n",
      "\tLoss_Stream/eval_phase/test_stream/Task000 = 7.7472\n",
      "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.1814\n",
      "Start of experience:  1\n",
      "Current Classes:  [1, 2]\n",
      "-- >> Start of training phase << --\n",
      "100%|██████████| 127/127 [00:00<00:00, 182.92it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.6103\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8739\n",
      "100%|██████████| 127/127 [00:00<00:00, 192.69it/s]\n",
      "Epoch 1 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0654\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9814\n",
      "100%|██████████| 127/127 [00:00<00:00, 188.60it/s]\n",
      "Epoch 2 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0519\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9860\n",
      "100%|██████████| 127/127 [00:00<00:00, 190.68it/s]\n",
      "Epoch 3 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0448\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9869\n",
      "-- >> End of training phase << --\n",
      "Training completed\n",
      "Computing accuracy on the whole test set\n",
      "-- >> Start of eval phase << --\n",
      "-- Starting eval on experience 0 (Task 0) from test stream --\n",
      "100%|██████████| 19/19 [00:00<00:00, 265.36it/s]\n",
      "> Eval on experience 0 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp000 = 7.5604\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.0000\n",
      "-- Starting eval on experience 1 (Task 0) from test stream --\n",
      "100%|██████████| 22/22 [00:00<00:00, 246.63it/s]\n",
      "> Eval on experience 1 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0316\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.9917\n",
      "-- Starting eval on experience 2 (Task 0) from test stream --\n",
      "100%|██████████| 20/20 [00:00<00:00, 266.38it/s]\n",
      "> Eval on experience 2 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp002 = 10.8943\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.0000\n",
      "-- Starting eval on experience 3 (Task 0) from test stream --\n",
      "100%|██████████| 21/21 [00:00<00:00, 270.88it/s]\n",
      "> Eval on experience 3 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp003 = 9.3077\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.0000\n",
      "-- Starting eval on experience 4 (Task 0) from test stream --\n",
      "100%|██████████| 21/21 [00:00<00:00, 241.89it/s]\n",
      "> Eval on experience 4 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp004 = 8.6138\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000\n",
      "-- >> End of eval phase << --\n",
      "\tLoss_Stream/eval_phase/test_stream/Task000 = 7.1449\n",
      "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.2149\n",
      "Start of experience:  2\n",
      "Current Classes:  [0, 8]\n",
      "-- >> Start of training phase << --\n",
      "100%|██████████| 118/118 [00:00<00:00, 197.27it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.8278\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8503\n",
      "100%|██████████| 118/118 [00:00<00:00, 195.27it/s]\n",
      "Epoch 1 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0593\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9834\n",
      "100%|██████████| 118/118 [00:00<00:00, 190.97it/s]\n",
      "Epoch 2 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0469\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9864\n",
      "100%|██████████| 118/118 [00:00<00:00, 185.49it/s]\n",
      "Epoch 3 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0400\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9882\n",
      "-- >> End of training phase << --\n",
      "Training completed\n",
      "Computing accuracy on the whole test set\n",
      "-- >> Start of eval phase << --\n",
      "-- Starting eval on experience 0 (Task 0) from test stream --\n",
      "100%|██████████| 19/19 [00:00<00:00, 264.18it/s]\n",
      "> Eval on experience 0 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp000 = 6.7370\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.0000\n",
      "-- Starting eval on experience 1 (Task 0) from test stream --\n",
      "100%|██████████| 22/22 [00:00<00:00, 231.28it/s]\n",
      "> Eval on experience 1 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp001 = 5.8661\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000\n",
      "-- Starting eval on experience 2 (Task 0) from test stream --\n",
      "100%|██████████| 20/20 [00:00<00:00, 254.30it/s]\n",
      "> Eval on experience 2 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp002 = 0.0296\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.9928\n",
      "-- Starting eval on experience 3 (Task 0) from test stream --\n",
      "100%|██████████| 21/21 [00:00<00:00, 239.80it/s]\n",
      "> Eval on experience 3 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp003 = 8.9566\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.0000\n",
      "-- Starting eval on experience 4 (Task 0) from test stream --\n",
      "100%|██████████| 21/21 [00:00<00:00, 238.97it/s]\n",
      "> Eval on experience 4 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp004 = 7.6316\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000\n",
      "-- >> End of eval phase << --\n",
      "\tLoss_Stream/eval_phase/test_stream/Task000 = 5.8656\n",
      "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.1940\n",
      "Start of experience:  3\n",
      "Current Classes:  [9, 3]\n",
      "-- >> Start of training phase << --\n",
      "100%|██████████| 121/121 [00:00<00:00, 184.16it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.8425\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8303\n",
      "100%|██████████| 121/121 [00:00<00:00, 174.39it/s]\n",
      "Epoch 1 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.1153\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9626\n",
      "100%|██████████| 121/121 [00:00<00:00, 180.35it/s]\n",
      "Epoch 2 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0947\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9702\n",
      "100%|██████████| 121/121 [00:00<00:00, 178.79it/s]\n",
      "Epoch 3 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0803\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9753\n",
      "-- >> End of training phase << --\n",
      "Training completed\n",
      "Computing accuracy on the whole test set\n",
      "-- >> Start of eval phase << --\n",
      "-- Starting eval on experience 0 (Task 0) from test stream --\n",
      "100%|██████████| 19/19 [00:00<00:00, 254.48it/s]\n",
      "> Eval on experience 0 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp000 = 7.6022\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.0000\n",
      "-- Starting eval on experience 1 (Task 0) from test stream --\n",
      "100%|██████████| 22/22 [00:00<00:00, 228.73it/s]\n",
      "> Eval on experience 1 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp001 = 6.4503\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000\n",
      "-- Starting eval on experience 2 (Task 0) from test stream --\n",
      "100%|██████████| 20/20 [00:00<00:00, 257.75it/s]\n",
      "> Eval on experience 2 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp002 = 7.2968\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.0010\n",
      "-- Starting eval on experience 3 (Task 0) from test stream --\n",
      "100%|██████████| 21/21 [00:00<00:00, 230.76it/s]\n",
      "> Eval on experience 3 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp003 = 0.0619\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.9802\n",
      "-- Starting eval on experience 4 (Task 0) from test stream --\n",
      "100%|██████████| 21/21 [00:00<00:00, 254.90it/s]\n",
      "> Eval on experience 4 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp004 = 9.5913\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000\n",
      "-- >> End of eval phase << --\n",
      "\tLoss_Stream/eval_phase/test_stream/Task000 = 6.1703\n",
      "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.1981\n",
      "Start of experience:  4\n",
      "Current Classes:  [4, 7]\n",
      "-- >> Start of training phase << --\n",
      "100%|██████████| 122/122 [00:00<00:00, 181.28it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.9374\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8095\n",
      "100%|██████████| 122/122 [00:00<00:00, 179.61it/s]\n",
      "Epoch 1 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0923\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9745\n",
      "100%|██████████| 122/122 [00:00<00:00, 180.16it/s]\n",
      "Epoch 2 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0673\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9811\n",
      "100%|██████████| 122/122 [00:00<00:00, 183.00it/s]\n",
      "Epoch 3 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.0579\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9838\n",
      "-- >> End of training phase << --\n",
      "Training completed\n",
      "Computing accuracy on the whole test set\n",
      "-- >> Start of eval phase << --\n",
      "-- Starting eval on experience 0 (Task 0) from test stream --\n",
      "100%|██████████| 19/19 [00:00<00:00, 246.32it/s]\n",
      "> Eval on experience 0 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp000 = 7.4585\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.0000\n",
      "-- Starting eval on experience 1 (Task 0) from test stream --\n",
      "100%|██████████| 22/22 [00:00<00:00, 228.07it/s]\n",
      "> Eval on experience 1 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp001 = 5.6269\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000\n",
      "-- Starting eval on experience 2 (Task 0) from test stream --\n",
      "100%|██████████| 20/20 [00:00<00:00, 247.90it/s]\n",
      "> Eval on experience 2 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp002 = 6.8030\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.0010\n",
      "-- Starting eval on experience 3 (Task 0) from test stream --\n",
      "100%|██████████| 21/21 [00:00<00:00, 220.23it/s]\n",
      "> Eval on experience 3 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp003 = 6.7649\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.0490\n",
      "-- Starting eval on experience 4 (Task 0) from test stream --\n",
      "100%|██████████| 21/21 [00:00<00:00, 240.90it/s]\n",
      "> Eval on experience 4 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0426\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.9871\n",
      "-- >> End of eval phase << --\n",
      "\tLoss_Stream/eval_phase/test_stream/Task000 = 5.3029\n",
      "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.2085\n"
     ]
    }
   ],
   "source": [
    "from avalanche.benchmarks.classic import SplitMNIST\n",
    "\n",
    "# scenario\n",
    "benchmark = SplitMNIST(n_experiences=5, seed=1)\n",
    "\n",
    "# TRAINING LOOP\n",
    "print('Starting experiment...')\n",
    "results = []\n",
    "for experience in benchmark.train_stream:\n",
    "    print(\"Start of experience: \", experience.current_experience)\n",
    "    print(\"Current Classes: \", experience.classes_in_this_experience)\n",
    "\n",
    "    cl_strategy.train(experience)\n",
    "    print('Training completed')\n",
    "\n",
    "    print('Computing accuracy on the whole test set')\n",
    "    results.append(cl_strategy.eval(benchmark.test_stream))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Adding Plugins\n",
    "\n",
    "Most continual learning strategies follow roughly the same training/evaluation loops, i.e. a simple naive strategy (a.k.a. finetuning) augmented with additional behavior to counteract catastrophic forgetting. The plugin systems in Avalanche is designed to easily augment continual learning strategies with custom behavior, without having to rewrite the training loop from scratch. Avalanche strategies accept an optional list of `plugins` that will be executed during the training/evaluation loops.\n",
    "\n",
    "For example, early stopping is implemented as a plugin:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from avalanche.training.plugins import EarlyStoppingPlugin\n",
    "\n",
    "strategy = Naive(\n",
    "    model, optimizer, criterion,\n",
    "    plugins=[EarlyStoppingPlugin(patience=10, val_stream_name='train')])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "In Avalanche, most continual learning strategies are implemented using plugins, which makes it easy to combine them together. For example, it is extremely easy to create a hybrid strategy that combines replay and EWC together by passing the appropriate `plugins` list to the `SupervisedTemplate`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from avalanche.training.templates import SupervisedTemplate\n",
    "from avalanche.training.plugins import ReplayPlugin, EWCPlugin\n",
    "\n",
    "replay = ReplayPlugin(mem_size=100)\n",
    "ewc = EWCPlugin(ewc_lambda=0.001)\n",
    "strategy = SupervisedTemplate(\n",
    "    model, optimizer, criterion,\n",
    "    plugins=[replay, ewc])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Beware that most strategy plugins modify the internal state. As a result, not all the strategy plugins can be combined together. For example, it does not make sense to use multiple replay plugins since they will try to modify the same strategy variables (mini-batches, dataloaders), and therefore they will be in conflict."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 📝 A Look Inside Avalanche Strategies\n",
    "\n",
    "If you arrived at this point you already know how to use Avalanche strategies and are ready to use it. However, before making your own strategies you need to understand a little bit the internal implementation of the training and evaluation loops.\n",
    "\n",
    "In _Avalanche_ you can customize a strategy in 2 ways:\n",
    "\n",
    "1. **Plugins**: Most strategies can be implemented as additional code that runs on top of the basic training and evaluation loops (e.g. the `Naive` strategy). Therefore, the easiest way to define a custom strategy such as a regularization or replay strategy, is to define it as a custom plugin. The advantage of plugins is that they can be combined, as long as they are compatible, i.e. they do not modify the same part of the state. The disadvantage is that in order to do so you need to understand the strategy loop, which can be a bit complex at first.\n",
    "2. **Subclassing**: In _Avalanche_, continual learning strategies inherit from the appropriate template, which provides generic training and evaluation loops. The most high level template is the `BaseTemplate`, from which all the _Avalanche_'s strategies inherit. Most template's methods can be safely overridden (with some caveats that we will see later).\n",
    "\n",
    "Keep in mind that if you already have a working continual learning strategy that does not use _Avalanche_, you can use most Avalanche components such as `benchmarks`, `evaluation`, and `models` without using _Avalanche_'s strategies!\n",
    "\n",
    "### Training and Evaluation Loops\n",
    "\n",
    "As we already mentioned, _Avalanche_ strategies inherit from the appropriate template (e.g. continual supervised learning strategies inherit from the `SupervisedTemplate`). These templates provide:\n",
    "\n",
    "1. Basic _Training_ and _Evaluation_ loops which define a naive (finetuning) strategy.\n",
    "2. _Callback_ points, which are used to call the plugins at a specific moments during the loop's execution.\n",
    "3. A set of variables representing the state of the loops (current model, data, mini-batch, predictions, ...) which allows plugins and child classes to easily manipulate the state of the training loop.\n",
    "\n",
    "The training loop has the following structure:\n",
    "```text\n",
    "train\n",
    "    before_training\n",
    "\n",
    "    before_train_dataset_adaptation\n",
    "    train_dataset_adaptation\n",
    "    after_train_dataset_adaptation\n",
    "    make_train_dataloader\n",
    "    model_adaptation\n",
    "    make_optimizer\n",
    "    before_training_exp  # for each exp\n",
    "        before_training_epoch  # for each epoch\n",
    "            before_training_iteration  # for each iteration\n",
    "                before_forward\n",
    "                after_forward\n",
    "                before_backward\n",
    "                after_backward\n",
    "            after_training_iteration\n",
    "            before_update\n",
    "            after_update\n",
    "        after_training_epoch\n",
    "    after_training_exp\n",
    "    after_training\n",
    "```\n",
    "\n",
    "The evaluation loop is similar:\n",
    "```text\n",
    "eval\n",
    "    before_eval\n",
    "    before_eval_dataset_adaptation\n",
    "    eval_dataset_adaptation\n",
    "    after_eval_dataset_adaptation\n",
    "    make_eval_dataloader\n",
    "    model_adaptation\n",
    "    before_eval_exp  # for each exp\n",
    "        eval_epoch  # we have a single epoch in evaluation mode\n",
    "            before_eval_iteration  # for each iteration\n",
    "                before_eval_forward\n",
    "                after_eval_forward\n",
    "            after_eval_iteration\n",
    "    after_eval_exp\n",
    "    after_eval\n",
    "```\n",
    "\n",
    "Methods starting with `before/after` are the methods responsible for calling the plugins.\n",
    "Notice that before the start of each experience during training we have several phases:\n",
    "- *dataset adaptation*: This is the phase where the training data can be modified by the strategy, for example by adding other samples from a separate buffer.\n",
    "- *dataloader initialization*: Initialize the data loader. Many strategies (e.g. replay) use custom dataloaders to balance the data.\n",
    "- *model adaptation*: Here, the dynamic models (see the `models` tutorial) are updated by calling their `adaptation` method.\n",
    "- *optimizer initialization*: After the model has been updated, the optimizer should also be updated to ensure that the new parameters are optimized.\n",
    "\n",
    "### Strategy State\n",
    "The strategy state is accessible via several attributes. Most of these can be modified by plugins and subclasses:\n",
    "- `self.clock`: keeps track of several event counters.\n",
    "- `self.experience`: the current experience.\n",
    "- `self.adapted_dataset`: the data modified by the dataset adaptation phase.\n",
    "- `self.dataloader`: the current dataloader.\n",
    "- `self.mbatch`: the current mini-batch. For supervised classification problems, mini-batches have the form `<x, y, t>`, where `x` is the input, `y` is the target class, and `t` is the task label.\n",
    "- `self.mb_output`: the current model's output.\n",
    "- `self.loss`: the current loss.\n",
    "- `self.is_training`: `True` if the strategy is in training mode.\n",
    "\n",
    "## How to Write a Plugin\n",
    "Plugins provide a simple solution to define a new strategy by augmenting the behavior of another strategy (typically the `Naive` strategy). This approach reduces the overhead and code duplication, **improving code readability and prototyping speed**.\n",
    "\n",
    "Creating a plugin is straightforward. As with strategies, you must create a class that inherits from the corresponding plugin template (`BasePlugin`, `BaseSGDPlugin`, `SupervisedPlugin`) and implements the callbacks that you need. The exact callback to use depend on the aim of your plugin. You can use the loop shown above to understand what callbacks you need to use. For example, we show below a simple replay plugin that uses `after_training_exp` to update the buffer after each training experience, and the `before_training_exp` to customize the dataloader. Notice that `before_training_exp` is executed after `make_train_dataloader`, which means that the `Naive` strategy already updated the dataloader. If we used another callback, such as `before_train_dataset_adaptation`, our dataloader would have been overwritten by the `Naive` strategy. Plugin methods always receive the `strategy` as an argument, so they can access and modify the strategy's state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-- >> Start of training phase << --\n",
      "100%|██████████| 89/89 [00:00<00:00, 160.77it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.1742\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9481\n",
      "100%|██████████| 100/100 [00:00<00:00, 141.55it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.5905\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8932\n",
      "100%|██████████| 92/92 [00:00<00:00, 139.20it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.5499\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8565\n",
      "100%|██████████| 95/95 [00:00<00:00, 141.14it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.6958\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8176\n",
      "100%|██████████| 95/95 [00:00<00:00, 138.03it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.7430\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8079\n",
      "-- >> End of training phase << --\n",
      "-- >> Start of eval phase << --\n",
      "-- Starting eval on experience 0 (Task 0) from test stream --\n",
      "100%|██████████| 15/15 [00:00<00:00, 195.45it/s]\n",
      "> Eval on experience 0 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp000 = 0.6133\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.8027\n",
      "-- Starting eval on experience 1 (Task 0) from test stream --\n",
      "100%|██████████| 17/17 [00:00<00:00, 216.84it/s]\n",
      "> Eval on experience 1 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp001 = 0.3809\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.8976\n",
      "-- Starting eval on experience 2 (Task 0) from test stream --\n",
      "100%|██████████| 16/16 [00:00<00:00, 222.21it/s]\n",
      "> Eval on experience 2 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp002 = 0.7461\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.8076\n",
      "-- Starting eval on experience 3 (Task 0) from test stream --\n",
      "100%|██████████| 16/16 [00:00<00:00, 186.18it/s]\n",
      "> Eval on experience 3 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp003 = 1.1641\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.4804\n",
      "-- Starting eval on experience 4 (Task 0) from test stream --\n",
      "100%|██████████| 16/16 [00:00<00:00, 210.76it/s]\n",
      "> Eval on experience 4 (Task 0) from test stream ended.\n",
      "\tLoss_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0796\n",
      "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.9786\n",
      "-- >> End of eval phase << --\n",
      "\tLoss_Stream/eval_phase/test_stream/Task000 = 0.5928\n",
      "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.7945\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Top1_Acc_Epoch/train_phase/train_stream/Task000': 0.8079439561165819,\n",
       " 'Loss_Epoch/train_phase/train_stream/Task000': 0.7429647520051568,\n",
       " 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.8027027027027027,\n",
       " 'Loss_Exp/eval_phase/test_stream/Task000/Exp000': 0.613303724881765,\n",
       " 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.8975542224273189,\n",
       " 'Loss_Exp/eval_phase/test_stream/Task000/Exp001': 0.3809299925587174,\n",
       " 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.8075742067553736,\n",
       " 'Loss_Exp/eval_phase/test_stream/Task000/Exp002': 0.7460512161986976,\n",
       " 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.4804358593363051,\n",
       " 'Loss_Exp/eval_phase/test_stream/Task000/Exp003': 1.1640707542100361,\n",
       " 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.9786069651741294,\n",
       " 'Loss_Exp/eval_phase/test_stream/Task000/Exp004': 0.0795803559137814,\n",
       " 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.7945,\n",
       " 'Loss_Stream/eval_phase/test_stream/Task000': 0.5928086629495025}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from avalanche.benchmarks.utils.data_loader import ReplayDataLoader\n",
    "from avalanche.core import SupervisedPlugin\n",
    "from avalanche.training.storage_policy import ReservoirSamplingBuffer\n",
    "\n",
    "\n",
    "class ReplayP(SupervisedPlugin):\n",
    "\n",
    "    def __init__(self, mem_size):\n",
    "        \"\"\" A simple replay plugin with reservoir sampling. \"\"\"\n",
    "        super().__init__()\n",
    "        self.buffer = ReservoirSamplingBuffer(max_size=mem_size)\n",
    "\n",
    "    def before_training_exp(self, strategy: \"SupervisedTemplate\",\n",
    "                            num_workers: int = 0, shuffle: bool = True,\n",
    "                            **kwargs):\n",
    "        \"\"\" Use a custom dataloader to combine samples from the current data and memory buffer. \"\"\"\n",
    "        if len(self.buffer.buffer) == 0:\n",
    "            # first experience. We don't use the buffer, no need to change\n",
    "            # the dataloader.\n",
    "            return\n",
    "        strategy.dataloader = ReplayDataLoader(\n",
    "            strategy.adapted_dataset,\n",
    "            self.buffer.buffer,\n",
    "            oversample_small_tasks=True,\n",
    "            num_workers=num_workers,\n",
    "            batch_size=strategy.train_mb_size,\n",
    "            shuffle=shuffle)\n",
    "\n",
    "    def after_training_exp(self, strategy: \"SupervisedTemplate\", **kwargs):\n",
    "        \"\"\" Update the buffer. \"\"\"\n",
    "        self.buffer.update(strategy, **kwargs)\n",
    "\n",
    "\n",
    "benchmark = SplitMNIST(n_experiences=5, seed=1)\n",
    "model = SimpleMLP(num_classes=10)\n",
    "optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
    "criterion = CrossEntropyLoss()\n",
    "strategy = Naive(model=model, optimizer=optimizer, criterion=criterion, train_mb_size=128,\n",
    "                 plugins=[ReplayP(mem_size=2000)])\n",
    "strategy.train(benchmark.train_stream)\n",
    "strategy.eval(benchmark.test_stream)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The animation below shows the execution and callback steps of a `Naive` strategy extended with the EWC plugin:\n",
    "\n",
    "<img align='center' style='max-width: 800px' src='../../.gitbook/assets/ewc_template_animation.gif'>\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Check the base plugin's documentation for the complete list of available callbacks."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## How to Write a Custom Strategy\n",
    "\n",
    "You can always define a custom strategy by overriding the corresponding template methods.\n",
    "However, there is an important caveat to keep in mind. If you override a method, you must remember to call all the callback handlers (the methods starting with `before`/`after`) at the appropriate points. For example, `train` calls `before_training` and `after_training` before and after the training loops, respectively. The easiest way to avoid mistakes is to start from the template method that you want to override and modify it to fit your needs without removing the callback handling.\n",
    "\n",
    "Notice that the `EvaluationPlugin` (see `evaluation` tutorial) uses the strategy callbacks.\n",
    "\n",
    "As an example, the `SupervisedTemplate`, for continual supervised strategies, provides the global state of the loop in the strategy's attributes, which you can safely use when you override a method. For instance, the `Cumulative` strategy trains a model continually on the union of all the experiences encountered so far. To achieve this, the cumulative strategy overrides `train_dataset_adaptation` and updates `self.adapted_dataset` by concatenating all the previous experiences with the current one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-- >> Start of training phase << --\n",
      "100%|██████████| 89/89 [00:00<00:00, 159.03it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.1343\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9636\n",
      "100%|██████████| 188/188 [00:01<00:00, 160.31it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.1294\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9639\n",
      "100%|██████████| 280/280 [00:01<00:00, 160.79it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.1963\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9433\n",
      "100%|██████████| 375/375 [00:02<00:00, 157.47it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.2452\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9294\n",
      "100%|██████████| 469/469 [00:02<00:00, 161.52it/s]\n",
      "Epoch 0 ended.\n",
      "\tLoss_Epoch/train_phase/train_stream/Task000 = 0.2460\n",
      "\tTop1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9285\n",
      "-- >> End of training phase << --\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Top1_Acc_Epoch/train_phase/train_stream/Task000': 0.92855,\n",
       " 'Loss_Epoch/train_phase/train_stream/Task000': 0.2460384528319041}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from avalanche.benchmarks.utils import concat_datasets\n",
    "from avalanche.training.templates import SupervisedTemplate\n",
    "\n",
    "\n",
    "class Cumulative(SupervisedTemplate):\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.dataset = None  # cumulative dataset\n",
    "\n",
    "    def train_dataset_adaptation(self, **kwargs):\n",
    "        super().train_dataset_adaptation(**kwargs)\n",
    "        curr_data = self.experience.dataset\n",
    "        if self.dataset is None:\n",
    "            self.dataset = curr_data\n",
    "        else:\n",
    "            self.dataset = concat_datasets([self.dataset, curr_data])\n",
    "        self.adapted_dataset = self.dataset.train()\n",
    "\n",
    "strategy = Cumulative(model=model, optimizer=optimizer, criterion=criterion, train_mb_size=128)\n",
    "strategy.train(benchmark.train_stream)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Easy, isn't it? :-\\)\n",
    "\n",
    "In general, we recommend _implementing a strategy via plugins_ whenever possible. This approach is the easiest to use and requires minimal knowledge of the strategy templates. It also allows other people to reuse your plugin and facilitates interoperability among different strategies.\n",
    "\n",
    "For example, replay strategies can be implemented either as a custom strategy or as plugins. However, a plugin lets you use replay in conjunction with other strategies or plugins, making it possible to combine different approaches to build the ultimate continual learning algorithm!"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "This completes the \"_Training_\" chapter for the \"_From Zero to Hero_\" series. We hope you enjoyed it!\n",
    "\n",
    "## 🤝 Run it on Google Colab\n",
    "\n",
    "You can run _this chapter_ and play with it on Google Colaboratory: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ContinualAI/avalanche/blob/master/notebooks/from-zero-to-hero-tutorial/04_training.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
